// We are not parallelizing over the channel dimension
// because during the backward pass, to compute the gradient,
// several pooled channels may write back to the same gradient
// channel, which would cause race conditions.
// Shared prologue for the channel max-pooling device functions.
//
// Expects these names in the enclosing scope: input, output, mask (pointers),
// and spatial_dim, channels, pooled_chann, num (ints).
//
// It derives the sample index n from the x launch dimension and the spatial
// index s from the z launch dimension, returns from the enclosing function
// when either is out of range, and then advances input / output / mask to
// the start of sample n. It leaves n, s, fea_dim and fea_dim_out declared in
// the enclosing scope.
//
// The per-sample strides are computed in size_t so that
// spatial_dim * channels * n cannot overflow 32-bit int on large tensors.
#define CHANN_MAX_POOL_BOUNDS_AND_INDEX \
  int n = threadIdx.x + blockIdx.x * blockDim.x; \
  int s = threadIdx.z + blockIdx.z * blockDim.z; \
  if (s >= spatial_dim || n >= num) \
    return; \
  size_t fea_dim = (size_t)spatial_dim * channels; \
  size_t fea_dim_out = (size_t)spatial_dim * pooled_chann; \
  input = input + fea_dim * n; \
  output = output + fea_dim_out * n; \
  mask = mask + fea_dim_out * n

// Forward pass of max pooling along the channel dimension.
//
// Each thread handles one (sample n, spatial position s) pair, selected by
// CHANN_MAX_POOL_BOUNDS_AND_INDEX, and walks over all pooled channels
// sequentially. Within one sample, element (c, s) is addressed as
// c*spatial_dim + s.
//
//   input        values to pool, num * channels * spatial_dim elements
//   output       receives the window maxima, num * pooled_chann * spatial_dim
//   mask         receives, for every output element, the input channel index
//                the maximum was taken from (same extent as output)
//   kernel       window length in channels
//   stride       distance between consecutive window starts
//   pre_pad      number of virtual channels in front of channel 0
//
// Pooled channel pc covers channels [pc*stride - pre_pad, +kernel), clipped
// to [0, channels). On ties the lowest channel index wins (strict '>').
template <typename T>
__device__ void max_channel_pooling_forward(const T *input, T *output, size_t *mask, 
    int spatial_dim, int channels, int num, 
    int pooled_chann, int kernel, int stride, int pre_pad) {

  CHANN_MAX_POOL_BOUNDS_AND_INDEX;

  // Nothing can be read from an input without channels.
  if (channels <= 0)
    return;

  for (int pc = 0; pc < pooled_chann; ++pc) {
    int cstart = pc*stride - pre_pad;
    int cend = min(cstart + kernel, channels);
    // Clamp the first channel into [0, channels-1]. A window that lies
    // entirely in the pre-padding collapses to channel 0; a window that
    // starts at or past the last channel collapses to channel channels-1
    // instead of reading past the end of the sample.
    cstart = min(max(cstart, 0), channels - 1);

    T maxval = input[cstart*spatial_dim + s];
    size_t maxidx = cstart;
    for (int c = cstart+1; c < cend; ++c) {
      const T val = input[c*spatial_dim + s];
      if (val > maxval) {
        maxval = val;
        maxidx = c;
      }
    }
    output[pc*spatial_dim + s] = maxval;
    mask[pc*spatial_dim + s] = maxidx;
  }
}

// Backward pass of max pooling along the channel dimension.
//
// The whole input (gradient) buffer must be set to zero before calling
// this: the function only accumulates with '+=' and never initializes.
//
// Each thread handles one (sample n, spatial position s) pair, selected by
// CHANN_MAX_POOL_BOUNDS_AND_INDEX, and walks over all pooled channels
// sequentially, adding each pooled gradient to the input channel recorded
// in mask.
//
//   input        gradient w.r.t. the pooling input (accumulated into)
//   output       gradient w.r.t. the pooling output (read only)
//   mask         per output element, the input channel index to route to
//
// kernel, stride and pre_pad are not needed here (the routing is fully
// described by mask); they are kept so the signature matches the forward
// pass.
//
// NOTE(review): the plain '+=' is only race-free if no two threads share
// the same (n, s); that presumably relies on the launch using extent 1 in
// the y dimension -- confirm against the host-side launch code.
template <typename T>
__device__ void max_channel_pooling_backward(T *input, const T *output, const size_t *mask,
    int spatial_dim, int channels, int num, 
    int pooled_chann, int kernel, int stride, int pre_pad) {

  CHANN_MAX_POOL_BOUNDS_AND_INDEX;
  for (int pc = 0; pc < pooled_chann; ++pc) {
    // Offset of pooled element (pc, s) inside this sample.
    const int pooled_idx = pc*spatial_dim + s;
    input[mask[pooled_idx]*spatial_dim + s] += output[pooled_idx];
  }
}

// Kernel entry points with C linkage (unmangled names), one per element
// type. Each simply forwards its arguments to the matching device template.
// Thread mapping, as established by the device functions' prologue: the x
// launch dimension indexes samples and the z launch dimension indexes
// spatial positions.
extern "C" {
  // Forward pass, single precision.
  __global__ void max_channel_pooling_forward_float(
      const float *input, float *output, size_t *mask,
      int spatial_dim, int channels, int num,
      int pooled_chann, int kernel, int stride, int pre_pad) {
    max_channel_pooling_forward<float>(
        input, output, mask,
        spatial_dim, channels, num,
        pooled_chann, kernel, stride, pre_pad);
  }

  // Forward pass, double precision.
  __global__ void max_channel_pooling_forward_double(
      const double *input, double *output, size_t *mask,
      int spatial_dim, int channels, int num,
      int pooled_chann, int kernel, int stride, int pre_pad) {
    max_channel_pooling_forward<double>(
        input, output, mask,
        spatial_dim, channels, num,
        pooled_chann, kernel, stride, pre_pad);
  }

  // Backward pass, single precision. input is accumulated into.
  __global__ void max_channel_pooling_backward_float(
      float *input, const float *output, const size_t *mask,
      int spatial_dim, int channels, int num,
      int pooled_chann, int kernel, int stride, int pre_pad) {
    max_channel_pooling_backward<float>(
        input, output, mask,
        spatial_dim, channels, num,
        pooled_chann, kernel, stride, pre_pad);
  }

  // Backward pass, double precision. input is accumulated into.
  __global__ void max_channel_pooling_backward_double(
      double *input, const double *output, const size_t *mask,
      int spatial_dim, int channels, int num,
      int pooled_chann, int kernel, int stride, int pre_pad) {
    max_channel_pooling_backward<double>(
        input, output, mask,
        spatial_dim, channels, num,
        pooled_chann, kernel, stride, pre_pad);
  }
}

// vim: ft=cuda
